#include "config.h"

#include <dirent.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <ctype.h>
#include <errno.h>

#define DBG_SUBSYS S_LIBYLIB

#include "sysy_lib.h"
#include "ypage.h"
#include "msgqueue.h"

/*
 * Header stored at offset 0 of every segment file. Message bytes start right
 * after it: every _pread/_pwrite on segment data adds sizeof(seg_meta_t) to
 * the in-segment offset. Only segment 0's header is read back on load.
 */
typedef struct {
        uint32_t roff;  /* read offset inside segment 0's data area; rewritten by msgqueue_pop */
} seg_meta_t;

/* Frame header placed in front of every pushed payload (fields len, crc, buf). */
typedef msgqueue_msg_t msg_t;

/*
 * Debug invariant over segments 0..queue->idx: every open segment either
 * still has unread bytes (roff < woff) or is completely empty
 * (roff == woff == 0). Use as a single statement: QUEUE_CHECK(queue);
 * (no trailing ';' inside the macro, so it is safe in unbraced if/else).
 */
#define QUEUE_CHECK(q_)                                                 \
        do {                                                            \
                uint32_t qc_i_;                                         \
                const msgqueue_seg_t *qc_seg_;                          \
                for (qc_i_ = 0; qc_i_ <= (q_)->idx; qc_i_++) {          \
                        qc_seg_ = &(q_)->seg[qc_i_];                    \
                        if (qc_seg_->fd != -1) {                        \
                                YASSERT(qc_seg_->roff < qc_seg_->woff || \
                                        (qc_seg_->roff == 0 && qc_seg_->woff == 0)); \
                        }                                               \
                }                                                       \
        } while (0)

/* Forward declaration: __msgqueue_truncate releases a segment before the definition further down. */
static inline void __msgqueue_freeseg(msgqueue_seg_t *seg, const char *home, uint32_t idx);

/*
 * Reset @queue to an empty in-memory state rooted at @path.
 *
 * Every segment descriptor and lock_fd are set to -1 before any validation,
 * so a queue left behind by a failed call never appears to own fd 0.
 *
 * @Param queue          structure to initialise (fully overwritten)
 * @Param path           queue directory, copied into queue->home
 * @Param msg_size       stored in queue->msg_size
 * @Param msgqueue_size  requested capacity; only range-checked here
 * @Param persistent     stored in queue->persistent
 *
 * @Returns 0, EINVAL if @msgqueue_size exceeds the segment table's
 *          capacity, or ENAMETOOLONG if @path does not fit in queue->home.
 */
static inline int __msgqueue_init(msgqueue_t *queue, const char *path,
                uint32_t msg_size, uint64_t msgqueue_size, int persistent)
{
        int ret, i, n;
        msgqueue_seg_t *seg;

        memset(queue, 0x0, sizeof(msgqueue_t));

        /* Mark all descriptors as "not open" first (memset left them at 0). */
        for (i = 0; i < MSGQUEUE_SEG_COUNT_MAX; i++) {
                seg = &queue->seg[i];

                seg->fd = -1;
                seg->woff = 0;
                seg->roff = 0;
        }
        queue->lock_fd = -1;

        /* The requested size must not exceed the maximum the segment table can hold. */
        if (msgqueue_size > (MSGQUEUE_SEG_LEN * (LLU)MSGQUEUE_SEG_COUNT_MAX)) {
                ret = EINVAL;
                GOTO(err_ret, ret);
        }

        /* A silently truncated home would make every segment path wrong. */
        n = snprintf(queue->home, MAX_PATH_LEN, "%s", path);
        if (n < 0 || n >= MAX_PATH_LEN) {
                ret = ENAMETOOLONG;
                GOTO(err_ret, ret);
        }

        queue->seg_len = MSGQUEUE_SEG_LEN;
        queue->msg_size = msg_size;
        queue->persistent = persistent;

        sy_rwlock_init(&queue->rwlock, "msgqueue.rwlock");

        return 0;
err_ret:
        return ret;
}

/**
 * Create a new, empty message queue under @path.
 *
 * Validates/creates @path, takes "<path>/lock" with daemon_lock() and
 * initialises @queue. The lock descriptor is stored in queue->lock_fd so
 * msgqueue_close() releases it, as msgqueue_load() already does.
 *
 * @Param queue
 * @Param path           directory holding the msgqueue data
 * @Param msg_size       length of a single msg
 * @Param msgqueue_size  length of the msgqueue
 * @Param persistent     whether segment files are kept on disk
 *
 * @Returns 0 on success, otherwise a positive errno value. On failure the
 *          lock (if taken) is released again.
 */
int msgqueue_init(msgqueue_t *queue, const char *path,
                uint32_t msg_size, uint64_t msgqueue_size, int persistent)
{
        int ret, fd;
        char lock[MAX_PATH_LEN];

        ret = path_validate(path, YLIB_ISDIR, YLIB_DIRCREATE);
        if (unlikely(ret))
                GOTO(err_ret, ret);

        snprintf(lock, MAX_PATH_LEN, "%s/lock", path);
        fd = daemon_lock(lock);
        if (fd < 0) {
                ret = -fd;
                GOTO(err_ret, ret);
        }

        ret = __msgqueue_init(queue, path, msg_size, msgqueue_size, persistent);
        if (unlikely(ret))
                GOTO(err_close, ret);

        /* __msgqueue_init reset lock_fd to -1; hand the lock to the queue now. */
        queue->lock_fd = fd;

        return 0;
err_close:
        close(fd);
err_ret:
        return ret;
}

/*
 * Number of unread bytes in the queue: all segments before queue->idx
 * count as full, plus what was written into the current one, minus what
 * has already been consumed from segment 0.
 */
static uint64_t  __msgqueue_total(msgqueue_t *queue)
{
        const LLU whole = (LLU)queue->idx * queue->seg_len;
        const LLU tail = queue->seg[queue->idx].woff;
        const LLU head = queue->seg[0].roff;

        return whole + tail - head;
}

/**
 * 确保msgqueue内的消息是queue->msg_size的整数倍
 *
 * @Param queue
 *
 * @Returns   
 */
static int  __msgqueue_truncate(msgqueue_t *queue)
{
        int idx;
        uint64_t total, total_new, offset;

        total = __msgqueue_total(queue);
        total_new = round_down(total, queue->msg_size); //向下对齐
        total_new = queue->seg[0].roff + total_new;

        idx = total_new / queue->seg_len; //第几个seg
        offset = total_new % queue->seg_len; //seg内偏移

        if (idx != queue->idx) {
                DWARN("truncate queue->idx from %d to %d\n", queue->idx, idx);
                YASSERT(idx == queue->idx - 1);
                __msgqueue_freeseg(&queue->seg[queue->idx], queue->home, queue->idx);
                queue->idx = idx;
        }
        
        if (queue->seg[idx].woff != offset) {
                DWARN("queue->seg[%d] roff=%llu, woff=%llu\n",
                                queue->idx, (LLU)queue->seg[idx].roff, (LLU)queue->seg[idx].woff);
                DWARN("truncate seg[%d] woff to %llu\n",
                                queue->idx, (LLU)offset);
                queue->seg[idx].woff = offset;
        }

        total = __msgqueue_total(queue);
        YASSERT(total % queue->msg_size == 0);

        return 0;
}

/**
 * Rebuild @queue from the segment files found in @path.
 *
 * Segment files are named by decimal index ("0", "1", ...) and begin with a
 * seg_meta_t header; only file 0's header (the persisted read offset) is
 * used. Gaps in the numbering are closed by renaming files down to
 * 0..count-1. A file no larger than the header is moved to "error.<n>".
 * When a segment other than the last one is not full-sized, the file after
 * it is deleted and the scan restarts; after a power loss the second-to-last
 * file may also be incomplete.
 *
 * @Param queue          queue to initialise
 * @Param path           queue directory (created if it does not exist)
 * @Param msg_len        forwarded to __msgqueue_init as msg_size
 * @Param msgqueue_size  forwarded to __msgqueue_init
 * @Param persistent     forwarded to __msgqueue_init
 *
 * @Returns 0 on success, otherwise a positive errno value. EINVAL is
 *          returned if the directory holds more segment files than
 *          queue->seg[] can track. On failure every segment descriptor
 *          opened here is closed again.
 */
int __msgqueue_load(msgqueue_t *queue, const char *path,
                uint32_t msg_len, uint64_t msgqueue_size, int persistent)
{
        int ret, count, max, i, j, n;
        DIR *dir;
        struct dirent *de;
        char path1[MAX_PATH_LEN], path2[MAX_PATH_LEN], path_err[MAX_PATH_LEN];
        seg_meta_t seg_meta;
        struct stat stbuf;
        msgqueue_seg_t *seg;

        ret = __msgqueue_init(queue, path, msg_len, msgqueue_size, persistent);
        if (unlikely(ret))
                GOTO(err_ret, ret);

retry:
        dir = opendir(queue->home);
        if (dir == NULL) {
                ret = errno;

                if (ret == ENOENT) {
                        /* No directory yet: create it and start with an empty queue. */
                        ret = mkdir(queue->home, 0777);
                        if (unlikely(ret)) {
                                ret = errno;
                                GOTO(err_ret, ret);
                        }

                        goto out;
                } else
                        GOTO(err_ret, ret);
        }

        /* Count segment files (names starting with a digit) and find the highest index. */
        count = 0;
        max = 0;
        while ((de = readdir(dir)) != NULL) {
                if (isdigit((unsigned char)de->d_name[0])) {
                        n = atoi(de->d_name);
                        if (n > max)
                                max = n;

                        count++;
                }
        }

        (void) closedir(dir);

        /* queue->seg[] has only MSGQUEUE_SEG_COUNT_MAX slots. */
        if (count > MSGQUEUE_SEG_COUNT_MAX) {
                DERROR("%s holds %d segment files, max %d\n",
                       queue->home, count, (int)MSGQUEUE_SEG_COUNT_MAX);
                ret = EINVAL;
                GOTO(err_ret, ret);
        }

        /* Validate the files and renumber them contiguously from 0. */
        j = 0;
        for (i = 0; i < count; i++) {
                while (1) {
                        YASSERT(j <= max);

                        snprintf(path2, MAX_PATH_LEN, "%s/%d", queue->home, j);

                        ret = stat(path2, &stbuf);
                        if (ret < 0) {
                                ret = errno;
                                if (ret == ENOENT) {
                                        j++;    /* gap in the numbering */
                                        continue;
                                }

                                GOTO(err_ret, ret);
                        }

                        if ((LLU)stbuf.st_size <= (LLU)sizeof(seg_meta_t)) {
                                snprintf(path_err, MAX_PATH_LEN, "%s/error.%d", queue->home, j);

                                DERROR("invalid file %s, move to %s\n", path2, path_err);

                                /* Must succeed, or the rescan would find the same file forever. */
                                ret = rename(path2, path_err);
                                if (ret < 0) {
                                        ret = errno;
                                        GOTO(err_ret, ret);
                                }

                                goto retry;
                        }

                        if ((size_t)stbuf.st_size != queue->seg_len + sizeof(seg_meta_t)
                                && (i != count -1)) {
                                DERROR("invalid file %s/%d\n", queue->home, j);
                                /* After a power loss the second-to-last file may also be incomplete. */
                                YASSERT(i == count - 2);
                                snprintf(path_err, MAX_PATH_LEN, "%s/%d", queue->home, j+1);

                                /* Delete the last file and retry; a failure here would loop forever. */
                                ret = unlink(path_err);
                                if (ret < 0) {
                                        ret = errno;
                                        GOTO(err_ret, ret);
                                }

                                goto retry;
                        }

                        break;
                }

                if (i != j) {
                        snprintf(path1, MAX_PATH_LEN, "%s/%d", queue->home, i);

                        DWARN("rname from %s to %s\n", path2, path1);

                        ret = rename(path2, path1);
                        if (ret < 0) {
                                ret = errno;
                                GOTO(err_ret, ret);
                        }
                }

                j++;
        }

        /* Open every segment; seg 0 restores roff, the last one derives woff from its size. */
        for (i = 0; i < count; i++) {
                seg = &queue->seg[i];

                snprintf(path1, MAX_PATH_LEN, "%s/%d", queue->home, i);

                ret = open(path1, O_RDWR);
                if (ret < 0) {
                        ret = errno;
                        GOTO(err_close, ret);
                }

                seg->fd = ret;

                if (i == 0) {
                        ret = _pread(seg->fd, &seg_meta, sizeof(seg_meta_t), 0);
                        if (ret < 0) {
                                ret = -ret;
                                GOTO(err_close, ret);
                        }

                        seg->roff = seg_meta.roff;
                }

                if (i == count - 1) {
                        ret = fstat(seg->fd, &stbuf);
                        if (ret < 0) {
                                ret = errno;
                                GOTO(err_close, ret);
                        }

                        seg->woff = stbuf.st_size - sizeof(seg_meta_t);
                } else
                        seg->woff = queue->seg_len;

                queue->idx = i;
        }

out:
        QUEUE_CHECK(queue);

        return 0;
err_close:
        /* Release every segment opened so far and leave an empty table. */
        for (i = 0; i < MSGQUEUE_SEG_COUNT_MAX; i++) {
                seg = &queue->seg[i];
                if (seg->fd != -1) {
                        close(seg->fd);
                        seg->fd = -1;
                }
                seg->woff = 0;
                seg->roff = 0;
        }
        queue->idx = 0;
err_ret:
        return ret;
}

/**
 * Open an existing msgqueue stored under @path.
 *
 * Takes "<path>/lock" via daemon_lock(), rebuilds the segment table from the
 * files in @path (__msgqueue_load) and drops a trailing partial message
 * (__msgqueue_truncate). On success the lock descriptor is kept in
 * queue->lock_fd, which msgqueue_close() closes.
 *
 * NOTE(review): unlike msgqueue_init(), @path is not validated/created
 * before the lock is taken. If daemon_lock() needs the directory to exist,
 * the mkdir fallback inside __msgqueue_load() can never run — confirm.
 *
 * @Param queue
 * @Param path           queue directory
 * @Param msg_size       length of a single msg; note it includes sizeof(msg_t)
 * @Param msgqueue_size  forwarded to __msgqueue_init()
 * @Param persistent     forwarded to __msgqueue_init()
 *
 * @Returns 0 on success, otherwise a positive errno value (the lock is
 *          released again on failure).
 */
int msgqueue_load(msgqueue_t *queue, const char *path,
                uint32_t msg_size, uint64_t msgqueue_size, int persistent)
{
        char lock_path[MAX_PATH_LEN];
        int lock_fd, ret;

        snprintf(lock_path, MAX_PATH_LEN, "%s/lock", path);

        lock_fd = daemon_lock(lock_path);
        if (lock_fd < 0) {
                ret = -lock_fd;
                GOTO(err_ret, ret);
        }

        ret = __msgqueue_load(queue, path, msg_size, msgqueue_size, persistent);
        if (unlikely(ret))
                GOTO(err_unlock, ret);

        ret = __msgqueue_truncate(queue);
        if (unlikely(ret))
                GOTO(err_unlock, ret);

        /* From here on the queue owns the lock. */
        queue->lock_fd = lock_fd;
        return 0;

err_unlock:
        close(lock_fd);
err_ret:
        return ret;
}

/*
 * Create (or open) segment file "<home>/<idx>" and attach it to @seg as an
 * empty segment. For non-persistent queues the name is unlinked straight
 * away, so only the open descriptor refers to the file.
 *
 * Returns 0 on success or the errno from open().
 */
static int __msgqueue_newseg(msgqueue_seg_t *seg, const char *home, uint32_t idx, int persistent)
{
        char path[MAX_PATH_LEN];
        int fd, ret;

        YASSERT(seg->fd == -1);

        snprintf(path, MAX_PATH_LEN, "%s/%u", home, idx);
        DINFO("open file %s\n", path);

        fd = open(path, O_RDWR | O_CREAT, 0644);
        if (fd < 0) {
                ret = errno;
                GOTO(err_ret, ret);
        }

        /* Fresh segment: nothing written yet, nothing consumed. */
        seg->fd = fd;
        seg->woff = 0;
        seg->roff = 0;

        if (!persistent)
                unlink(path);

        return 0;
err_ret:
        return ret;
}

/*
 * pthread entry point: fsync the descriptor of the msgqueue_seg_t passed in
 * @_arg. Descriptors <= 0 are treated as "nothing to sync".
 */
void *__msgqueue_syncseg__(void *_arg)
{
        msgqueue_seg_t *target = _arg;

        if (target->fd <= 0)
                return NULL;

        fsync(target->fd);
        return NULL;
}

/**
 * 异步sync fd
 *
 * @Param seg
 *
 * @Returns   
 */
int __msgqueue_syncseg(const msgqueue_seg_t *seg)
{
        int ret;
        pthread_t th;
        pthread_attr_t ta;

        (void) pthread_attr_init(&ta);
        (void) pthread_attr_setdetachstate(&ta, PTHREAD_CREATE_DETACHED);

        ret = pthread_create(&th, &ta, __msgqueue_syncseg__, (void *)seg);
        if (unlikely(ret))
                GOTO(err_ret, ret);

        return 0;
err_ret:
        return ret;
}

/**
 * Append one message to the tail of the queue.
 *
 * The payload is framed as a msg_t header (len, crc32 of the payload)
 * followed by @len payload bytes. If the frame does not fit in the current
 * segment, the first part fills that segment and the rest is written to the
 * start of a newly created next segment.
 *
 * @Param queue
 * @Param _msg  payload
 * @Param len   payload length; must be <= MAX_BUF_LEN - sizeof(msg_t)
 *
 * @Returns 0 on success, a negative value on error (-ENOSPC when the last
 *          allowed segment cannot take the frame).
 */
int msgqueue_push(msgqueue_t *queue, const void *_msg, uint32_t len)
{
        int ret;
        msgqueue_seg_t *seg;
        uint32_t left, cp, msglen;
        msg_t *msg;
        char buf[MAX_BUF_LEN];

        YASSERT(len <= MAX_BUF_LEN - sizeof(msg_t));

        ret = sy_rwlock_wrlock(&queue->rwlock);
        if (unlikely(ret))
                GOTO(err_ret, ret);

        /*
         * Read idx only under the lock (msgqueue_pop rewrites it too), and
         * before QUEUE_CHECK, which walks seg[0..idx].
         */
        if (queue->idx >= MSGQUEUE_SEG_COUNT_MAX) {
                ret = EINVAL;
                GOTO(err_lock, ret);
        }

        QUEUE_CHECK(queue);

        seg = &queue->seg[queue->idx];

        if (seg->fd == -1) {
                ret = __msgqueue_newseg(seg, queue->home, queue->idx, queue->persistent);
                if (unlikely(ret))
                        UNIMPLEMENTED(__DUMP__);
        }

        msglen = sizeof(msg_t) + len;

        /*
         * In the last allowed segment the frame must leave the segment
         * non-full: filling it exactly would advance idx to
         * MSGQUEUE_SEG_COUNT_MAX below and write past the end of seg[].
         */
        if (queue->idx == MSGQUEUE_SEG_COUNT_MAX - 1
            && seg->woff + msglen >= queue->seg_len) {
                ret = ENOSPC;
                GOTO(err_lock, ret);
        }

        msg = (void *)buf;

        msg->len = len;
        msg->crc = crc32_sum(_msg, len);
        memcpy(msg->buf, _msg, len);

        left = queue->seg_len - seg->woff;

        cp = left < msglen ? left : msglen;

        ret = _pwrite(seg->fd, buf, cp, seg->woff + sizeof(seg_meta_t));
        if (ret < 0) {
                ret = -ret;
                GOTO(err_lock, ret);
        }

        seg->woff += cp;

        if (seg->woff == queue->seg_len) {
                /* Must be asynchronous: this runs in the core and must not block. */
                __msgqueue_syncseg(seg);
                queue->idx++;
                seg = &queue->seg[queue->idx];
                ret = __msgqueue_newseg(seg, queue->home, queue->idx, queue->persistent);
                if (unlikely(ret))
                        UNIMPLEMENTED(__DUMP__);

                /* Write the part of the frame that did not fit into the new segment. */
                if (msglen > cp) {
                        ret = _pwrite(seg->fd, buf + cp, msglen - cp,
                                      seg->woff + sizeof(seg_meta_t));
                        if (ret < 0) {
                                ret = -ret;
                                GOTO(err_lock, ret);
                        }

                        seg->woff += msglen - cp;
                }
        }

        QUEUE_CHECK(queue);

        sy_rwlock_unlock(&queue->rwlock);

        return 0;
err_lock:
        sy_rwlock_unlock(&queue->rwlock);
err_ret:
        return -ret;
}

/*
 * Close segment @seg, remove its file "<home>/<idx>" and reset the slot to
 * the "not open" state (fd -1, offsets 0).
 */
static inline void __msgqueue_freeseg(msgqueue_seg_t *seg, const char *home, uint32_t idx)
{
        char seg_path[MAX_PATH_LEN];

        close(seg->fd);

        snprintf(seg_path, MAX_PATH_LEN, "%s/%u", home, idx);
        unlink(seg_path);

        seg->roff = 0;
        seg->woff = 0;
        seg->fd = -1;
}

/*
 * Copy up to *len bytes from the head of the queue into @msg without
 * consuming them. Both callers in this file hold queue->rwlock around it.
 *
 * @len: in: capacity of @msg; out: number of bytes actually copied
 *       (0 when the queue is empty).
 *
 * Returns 0 or a positive errno value from _pread.
 */
static inline int __msgqueue_get(msgqueue_t *queue, void *msg, uint32_t *len)
{
        int ret;
        uint32_t left, cp, avail, i, buflen;
        uint64_t total;
        msgqueue_seg_t *seg;
        char *dst = msg;        /* byte cursor: arithmetic on void * is not ISO C */

        total = __msgqueue_total(queue);
        if (total == 0) {
                *len = 0;
                goto out;
        }

        left = *len;
        buflen = *len;

        /* Walk the segments from the head, taking as much of each as still fits. */
        for (i = 0; i <= queue->idx; i++) {
                seg = &queue->seg[i];
                avail = seg->woff - seg->roff;
                cp = left < avail ? left : avail;

                ret = _pread(seg->fd, dst + (buflen - left), cp,
                             seg->roff + sizeof(seg_meta_t));
                if (ret < 0) {
                        ret = -ret;
                        GOTO(err_ret, ret);
                }

                left -= cp;

                if (left == 0)
                        break;
        }

        *len = buflen - left;

out:
        return 0;
err_ret:
        return ret;
}

/**
 * Peek at the head of the queue without consuming it.
 *
 * @Param queue
 * @Param msg  destination buffer of at least @len bytes
 * @Param len  maximum number of bytes to copy
 *
 * @Returns the number of bytes copied (0 when the queue is empty), or a
 *          negative value on error.
 */
int msgqueue_get(msgqueue_t *queue, void *msg, uint32_t len)
{
        uint32_t copied = len;
        int ret;

        ret = sy_rwlock_wrlock(&queue->rwlock);
        if (unlikely(ret))
                GOTO(err_ret, ret);

        QUEUE_CHECK(queue);

        ret = __msgqueue_get(queue, msg, &copied);
        if (unlikely(ret))
                GOTO(err_unlock, ret);

        QUEUE_CHECK(queue);

        sy_rwlock_unlock(&queue->rwlock);

        return copied;
err_unlock:
        sy_rwlock_unlock(&queue->rwlock);
err_ret:
        return -ret;
}

/**
 * Remove data from the head of the queue.
 *
 * If @msg is NULL, up to @len bytes are simply deleted. Otherwise up to
 * @len bytes are first copied into @msg and then deleted. Fully consumed
 * leading segments are freed and the remaining ones are shifted down (both
 * the table slot and the file name). The new read offset is persisted in
 * segment 0's header.
 *
 * @Param queue
 * @Param msg  optional destination buffer of at least @len bytes
 * @Param len  number of bytes to pop; must not exceed queue->seg_len
 *
 * @Returns the number of bytes removed (0 if the queue was empty), or a
 *          negative value on error (-EINVAL if @len > seg_len).
 */
int msgqueue_pop(msgqueue_t *queue, void *msg, uint32_t len)
{
        int ret;
        uint32_t count, i, off0, buflen;
        uint64_t total;
        msgqueue_seg_t *seg;
        char path1[MAX_PATH_LEN], path2[MAX_PATH_LEN];
        seg_meta_t seg_meta;

        ret = sy_rwlock_wrlock(&queue->rwlock);
        if (unlikely(ret))
                GOTO(err_ret, ret);

        QUEUE_CHECK(queue);

        /* A single pop never exceeds seg_len, so at most one segment is released. */
        if (len > queue->seg_len ) {
                ret = EINVAL;
                GOTO(err_lock, ret);
        }

        buflen = len;

        if (msg) {
                ret = __msgqueue_get(queue, msg, &buflen);
                if (unlikely(ret))
                        GOTO(err_lock, ret);
        }

        total = __msgqueue_total(queue);
        if (total == 0) {
                /* Nothing stored, nothing removed (also when msg == NULL). */
                buflen = 0;
                goto out;
        }

        buflen = total < buflen ? total : buflen;

        /* Number of leading segments completely consumed by this pop. */
        count = (buflen + queue->seg[0].roff) / queue->seg_len;

        if (count) {
                off0 = queue->seg[0].roff;

                for (i = 0; i < count; i++) {
                        seg = &queue->seg[i];
                        __msgqueue_freeseg(seg, queue->home, i);
                }

                /* Shift the remaining segments down by count: file name first, then slot. */
                for (i = count; i <= queue->idx; i++) {
                        seg = &queue->seg[i];

                        snprintf(path1, MAX_PATH_LEN, "%s/%u", queue->home, i);
                        snprintf(path2, MAX_PATH_LEN, "%s/%u", queue->home, i - count);

                        ret = rename(path1, path2);
                        if (ret < 0) {
                                ret = errno;
                                /*
                                 * Non-persistent segments are unlinked as soon as
                                 * __msgqueue_newseg creates them: no file to rename.
                                 */
                                if (ret != ENOENT || queue->persistent)
                                        GOTO(err_lock, ret);
                        }

                        /* Move the slot only after the rename, so a failure never leaves two slots sharing an fd. */
                        queue->seg[i - count] = *seg;
                        seg->fd = -1;
                        seg->woff = 0;
                        seg->roff = 0;
                }

                queue->idx -= count;

                queue->seg[0].roff = (off0 + buflen) % queue->seg_len;
        } else
                queue->seg[0].roff += buflen;

        /* The last remaining segment has been fully consumed: drop it. */
        if (queue->seg[0].roff == queue->seg[0].woff) {
                YASSERT(queue->idx == 0);

                __msgqueue_freeseg(&queue->seg[0], queue->home, 0);
        }

        QUEUE_CHECK(queue);

        /* Persist the new read offset in segment 0's header. */
        if (queue->seg[0].fd != -1) {
                seg_meta.roff = queue->seg[0].roff;

                ret = _pwrite(queue->seg[0].fd, &seg_meta, sizeof(seg_meta_t), 0);
                if (ret < 0) {
                        ret = -ret;
                        GOTO(err_lock, ret);
                }
        }

        QUEUE_CHECK(queue);

out:
        sy_rwlock_unlock(&queue->rwlock);

        return buflen;
err_lock:
        sy_rwlock_unlock(&queue->rwlock);
err_ret:
        return -ret;
}

/*
 * Return 1 when the queue holds no unread bytes, 0 otherwise.
 * The segment offsets are read without taking queue->rwlock.
 */
int msgqueue_empty(msgqueue_t *queue)
{
        return __msgqueue_total(queue) == 0;
}

/*
 * Release the directory lock and close every open segment descriptor.
 *
 * Each descriptor is reset to -1 after closing. A second call, or an
 * accidental use of the queue afterwards, therefore cannot close an
 * unrelated descriptor that reused the same number. A queue without a
 * lock (lock_fd == -1) no longer issues close(-1).
 */
void msgqueue_close(msgqueue_t *queue)
{
        int i;
        msgqueue_seg_t *seg;

        if (queue->lock_fd != -1) {
                close(queue->lock_fd);
                queue->lock_fd = -1;
        }

        for (i = 0; i < MSGQUEUE_SEG_COUNT_MAX; i++) {
                seg = &queue->seg[i];
                if (seg->fd != -1) {
                        close(seg->fd);
                        seg->fd = -1;
                }
        }
}
